import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np

import os
import random

# Use the GPU when one is visible; the chosen device is echoed for logging.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Seed every RNG the script touches (Python, NumPy, PyTorch) so runs are
# reproducible. torch.manual_seed seeds the RNGs of all devices, CUDA included,
# so a separate torch.cuda.manual_seed call is unnecessary.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# New floating-point tensors default to float64. This is the supported
# replacement for the deprecated torch.set_default_tensor_type(torch.DoubleTensor).
torch.set_default_dtype(torch.float64)

class NormFaceLoss(nn.Module):
    """Temperature-scaled softmax cross-entropy over (cosine) logits.

    Args:
        temp: factor the logits are multiplied by before the softmax.
        weight: optional per-class weight forwarded to ``F.cross_entropy``.
    """

    def __init__(self, temp=10, weight=None):
        super().__init__()
        self.temp = temp
        self.weight = weight

    def forward(self, logits, labels):
        """Return the mean cross-entropy of ``temp * logits`` against ``labels``."""
        scaled = self.temp * logits
        # Softmax is shift-invariant: subtracting the row maximum leaves the
        # loss unchanged and only guards against overflow.
        row_max = scaled.max(dim=1, keepdim=True).values
        shifted = scaled - row_max
        return F.cross_entropy(shifted, labels, weight=self.weight)


class SampleMarginLoss(nn.Module):
    """Mean of (hardest-negative logit - target logit) over a batch.

    For each row the loss term is ``max_{j != y} logits[j] - logits[y]``.
    It is negative when the target logit exceeds every other logit.

    Args:
        weight: optional 1-D tensor of per-class weights; each sample's term
            is multiplied by ``weight[label]`` before averaging.
    """

    def __init__(self, weight=None):
        super(SampleMarginLoss, self).__init__()
        self.weight = weight

    def forward(self, logits, labels):
        """Compute the loss.

        Args:
            logits: ``(N, C)`` scores; needs ``C >= 2`` for a finite result.
            labels: ``(N,)`` integer class indices.

        Returns:
            Scalar tensor: the (optionally class-weighted) mean margin term.
        """
        labels = labels.view(-1)
        # gather reads the target logit directly; unlike sum(logits * one_hot)
        # it cannot be poisoned by inf * 0 = nan in other columns.
        target = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
        # Exclude the target column with -inf instead of a finite sentinel
        # such as -1, so the hardest negative is correct for logits of any
        # range, not only cosines bounded below by -1.
        is_target = F.one_hot(labels, logits.size(1)).bool().to(logits.device)
        hardest_negative = logits.masked_fill(is_target, float('-inf')).max(dim=1)[0]
        loss = hardest_negative - target
        if self.weight is not None:
            loss = loss * self.weight.gather(0, labels)
        return loss.mean()


def generate_weight(n_classes, n_hiddens, use_relu=False, epochs=400000):
    """Optimise ``n_classes`` prototypes in ``n_hiddens`` dims and save the best.

    One learnable feature vector per class (``Z``) is trained jointly with the
    class-weight matrix ``W`` under ``NormFaceLoss`` on cosine logits.  Each
    step is scored as ``ClsMargin * SaMargin``, where:

    * ClsMargin is the smallest pairwise angle, in degrees, between rows of W.
    * SaMargin is the negated ``SampleMarginLoss``.

    Both are measured on the same W that produced the logits.  The W with the
    highest positive score is row-normalised and written to
    ``./prototypes/weight{0|1}_{n_classes}x{n_hiddens}.npy`` (``1`` when
    ``use_relu`` is set).

    Args:
        n_classes: number of prototypes (rows of the saved matrix).
        n_hiddens: dimensionality of each prototype.
        use_relu: apply ReLU to the features before normalisation.
        epochs: number of optimisation steps (default keeps the original
            400000).

    Returns:
        The saved ``(n_classes, n_hiddens)`` numpy array.
    """
    n_samples = n_classes * 1  # one learnable feature vector per class
    n_per_class = n_samples // n_classes

    labels = []
    for i in range(n_classes):
        labels += [i] * n_per_class
    np.random.shuffle(labels)
    labels = torch.LongTensor(labels).to(device)

    def get_margin(weight):
        """Smallest pairwise angle (degrees) between distinct rows of ``weight``."""
        with torch.no_grad():  # monitoring only; no autograd graph needed
            tmp = F.normalize(weight, dim=1)
            # Subtracting 2 on the diagonal moves self-similarity (~1) to ~-1,
            # so torch.max picks the most similar *distinct* pair.
            sim = torch.matmul(tmp, tmp.transpose(1, 0)) - 2 * torch.eye(
                tmp.size(0), device=weight.device, dtype=tmp.dtype)
            # Keep acos inside its domain despite rounding.
            sim = torch.clamp(sim, -1 + 1e-7, 1 - 1e-7)
            return torch.acos(torch.max(sim)).item() / np.pi * 180

    def evaluate(out, labels):
        """Fraction of rows whose arg-max logit equals the label."""
        pred = torch.argmax(out, 1)
        return (pred == labels).sum().item() / labels.size(0)

    Z = torch.randn(n_samples, n_hiddens).to(device)
    Z.requires_grad = True
    W = torch.randn(n_classes, n_hiddens).to(device)
    W.requires_grad = True
    nn.init.kaiming_normal_(W)

    optimizer = torch.optim.SGD([Z, W], lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20000, eta_min=0)

    criterion = NormFaceLoss(temp=5)
    sample_margin = SampleMarginLoss()
    score = 0
    best_weight = None
    for i in range(epochs):
        z = F.relu(Z) if use_relu else Z
        out = F.linear(F.normalize(z, dim=1), F.normalize(W, dim=1))
        loss = criterion(out, labels)

        # Score the parameters that produced ``out`` *before* stepping, so
        # SaMargin, ClsMargin and the saved snapshot all describe the same W.
        sm = -sample_margin(out, labels).item()
        margin = get_margin(W)
        if score < margin * sm:
            # copy=True: on CPU, W.cpu() is W itself, so without an explicit
            # copy the "best" snapshot would keep changing with every step.
            best_weight = W.detach().to('cpu', copy=True)
            score = margin * sm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if i % 500 == 0:
            acc = evaluate(out, labels)
            print('Iter {}: loss={:.4f}, acc={:.4f}, SaMargin={:.4f}, ClsMargin={:.4f}'.format(i, loss.item(), acc, sm, margin))

    if best_weight is None:
        # No step reached a positive score (e.g. epochs == 0); fall back to
        # the final weights instead of crashing on F.normalize(None).
        print('No positive score reached; saving the final weights.')
        best_weight = W.detach().to('cpu', copy=True)

    weight = F.normalize(best_weight, dim=1).numpy()
    os.makedirs('./prototypes', exist_ok=True)
    path = './prototypes/weight1_' if use_relu else './prototypes/weight0_'
    np.save(path + str(n_classes) + 'x' + str(n_hiddens) + '.npy', weight)
    return weight


if __name__ == '__main__':
    # Script entry point: 1000 prototypes of dimension 2048, features used
    # without ReLU.
    generate_weight(n_classes=1000, n_hiddens=2048, use_relu=False)